/*
* Copyright (c) 2012-2013 Spotify AB
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
package com.spotify.netty4.handler.codec.zmtp;
import com.google.common.util.concurrent.SettableFuture;
import org.junit.After;
import org.junit.Before;
import org.junit.experimental.theories.Theories;
import org.junit.experimental.theories.Theory;
import org.junit.experimental.theories.suppliers.TestedOn;
import org.junit.runner.RunWith;
import java.net.InetSocketAddress;
import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.util.ReferenceCountUtil;
import static com.spotify.netty4.handler.codec.zmtp.ZMTPProtocols.ZMTP20;
import static com.spotify.netty4.handler.codec.zmtp.ZMTPSocketType.ROUTER;
import static java.util.concurrent.TimeUnit.SECONDS;
import static org.junit.Assert.assertFalse;
/**
 * Checks how a ZMTP/2.0 ROUTER server pipeline reacts to raw, non-ZMTP bytes sent by a plain
 * Netty client.
 *
 * <p>For each payload size, the theory asserts four things about the server side:
 * <ul>
 *   <li>the connection becomes active;</li>
 *   <li>an exception reaches the handler behind the codec;</li>
 *   <li>the connection is closed;</li>
 *   <li>neither a handshake-success event nor any inbound message reaches that handler.</li>
 * </ul>
 */
@RunWith(Theories.class)
public class ProtocolViolationTests {

  private Channel serverChannel;
  private InetSocketAddress serverAddress;

  private final String identity = "identity";

  private NioEventLoopGroup bossGroup;
  private NioEventLoopGroup group;

  /**
   * Terminal pipeline handler that records the lifecycle events it observes so the test thread
   * can wait on, or assert against, them.
   *
   * <p>Marked {@code @Sharable} because the server adds the same instance to every accepted
   * child channel.
   */
  @ChannelHandler.Sharable
  private static class MockHandler extends ChannelInboundHandlerAdapter {

    /** Completed when the channel becomes active. */
    private final SettableFuture<Void> active = SettableFuture.create();
    /** Completed with the first exception that reaches this handler. */
    private final SettableFuture<Throwable> exception = SettableFuture.create();
    /** Completed when the channel becomes inactive. */
    private final SettableFuture<Void> inactive = SettableFuture.create();

    /** Set when a {@link ZMTPHandshakeSuccess} user event is observed. */
    private volatile boolean handshaked;
    /** Set when any inbound message reaches this handler. */
    private volatile boolean read;

    @Override
    public void channelActive(final ChannelHandlerContext ctx) throws Exception {
      active.set(null);
    }

    @Override
    public void channelInactive(final ChannelHandlerContext ctx) throws Exception {
      inactive.set(null);
    }

    @Override
    public void userEventTriggered(final ChannelHandlerContext ctx, final Object evt)
        throws Exception {
      if (evt instanceof ZMTPHandshakeSuccess) {
        handshaked = true;
      }
    }

    @Override
    public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception {
      // This is the last handler in the pipeline, so it owns the message and must release it.
      ReferenceCountUtil.release(msg);
      read = true;
    }

    @Override
    public void exceptionCaught(final ChannelHandlerContext ctx, final Throwable cause)
        throws Exception {
      exception.set(cause);
      ctx.close();
    }
  }

  private final MockHandler mockHandler = new MockHandler();

  /**
   * Starts a server on an ephemeral localhost port.
   *
   * <p>Each accepted channel gets a ZMTP/2.0 ROUTER codec followed by the shared
   * {@link #mockHandler}.
   */
  @Before
  public void setup() {
    final ServerBootstrap serverBootstrap = new ServerBootstrap();
    serverBootstrap.channel(NioServerSocketChannel.class);
    bossGroup = new NioEventLoopGroup(1);
    group = new NioEventLoopGroup();
    serverBootstrap.group(bossGroup, group);
    serverBootstrap.childHandler(new ChannelInitializer<NioSocketChannel>() {
      @Override
      protected void initChannel(final NioSocketChannel ch) throws Exception {
        ch.pipeline().addLast(
            ZMTPCodec.builder()
                .protocol(ZMTP20)
                .socketType(ROUTER)
                .localIdentity(identity)
                .build(),
            mockHandler);
      }
    });
    // syncUninterruptibly() rethrows a bind failure instead of leaving an unbound channel
    // whose localAddress() cannot be used.
    serverChannel = serverBootstrap.bind(new InetSocketAddress("localhost", 0))
        .syncUninterruptibly().channel();
    serverAddress = (InetSocketAddress) serverChannel.localAddress();
  }

  /** Closes the server channel and releases both server event loop groups. */
  @After
  public void teardown() {
    if (serverChannel != null) {
      serverChannel.close();
    }
    if (bossGroup != null) {
      bossGroup.shutdownGracefully();
    }
    if (group != null) {
      group.shutdownGracefully();
    }
  }

  /**
   * Sends {@code payloadSize} zero bytes from a plain (non-ZMTP) client.
   *
   * <p>Asserts that the server pipeline reports an exception and closes the connection without
   * ever signalling a completed handshake or delivering a message.
   *
   * @param payloadSize number of zero bytes to send
   */
  @Theory
  public void protocolErrorsCauseException(
      @TestedOn(ints = {16, 17, 27, 32, 48, 53}) final int payloadSize) throws Exception {
    // The client gets its own event loop group per invocation. Shut it down afterwards so its
    // threads do not leak across the theory's data points.
    final NioEventLoopGroup clientGroup = new NioEventLoopGroup();
    try {
      final Bootstrap b = new Bootstrap();
      b.group(clientGroup);
      b.channel(NioSocketChannel.class);
      b.handler(new ChannelInitializer<NioSocketChannel>() {
        @Override
        protected void initChannel(final NioSocketChannel ch) throws Exception {
          ch.pipeline().addLast(new MockHandler());
        }
      });

      // Fail fast on a connect error rather than timing out below with a misleading message.
      final Channel channel = b.connect(serverAddress).syncUninterruptibly().channel();

      // A freshly allocated Java array is zero-filled.
      final ByteBuf payload = Unpooled.wrappedBuffer(new byte[payloadSize]);
      channel.writeAndFlush(payload);

      mockHandler.active.get(5, SECONDS);
      mockHandler.exception.get(5, SECONDS);
      mockHandler.inactive.get(5, SECONDS);
      assertFalse(mockHandler.handshaked);
      assertFalse(mockHandler.read);
    } finally {
      clientGroup.shutdownGracefully();
    }
  }
}